from torch import nn

from . import lenet

__all__ = ["lenet"]

# Registry of buildable models: public name -> constructor (expected to
# produce an nn.Module). get_model() looks names up here and
# get_model_names() exposes its keys, so register new architectures here.
__model_list = dict(
    lenet=lenet.LeNet,
)


def get_model_names():
    """
    List the model names that get_model() accepts.
    :return: A live ``dict_keys`` view of the registered model names; it
        reflects later changes to the registry.
    """
    registry = __model_list
    return registry.keys()


def get_model(model_name: str, *args, **kwargs) -> nn.Module:
    """
    Build a registered model by name.
    :param model_name: The registered name of the model (see get_model_names()).
    :param args: Positional args forwarded to the model constructor.
    :param kwargs: Keyword args forwarded to the model constructor.
    :return: The constructed model instance.
    :raises KeyError: If ``model_name`` is not registered; the message lists
        the available model names.
    """
    # Keep the lookup alone in the try block so a KeyError raised inside a
    # model constructor is not misreported as an unknown model name.
    try:
        model_cls = __model_list[model_name]
    except KeyError:
        available = ", ".join(sorted(__model_list))
        # Still a KeyError so existing `except KeyError` callers keep working;
        # `from None` hides the uninformative original lookup traceback.
        raise KeyError(
            f"Unknown model {model_name!r}; available models: {available}"
        ) from None
    return model_cls(*args, **kwargs)
